Simple Tasks and Symmetry

using the e3nn repository

tutorial by: Tess E. Smidt

code by:


@misc{mario_geiger_2019_3348277,
  author       = {Mario Geiger and
                  Tess Smidt and
                  Wouter Boomsma and
                  Maurice Weiler and
                  Michał Tyszkiewicz and
                  Jes Frellsen and
                  Benjamin K. Miller and
                  Josh Rackers},
  title        = {e3nn/e3nn: Point cloud support},
  month        = jul,
  year         = 2019,
  doi          = {10.5281/zenodo.3348277},
  url          = {https://doi.org/10.5281/zenodo.3348277}
}

There are some unintuitive consequences of using E(3)-equivariant neural networks. The symmetry of your output has to be equal to or higher than the symmetry of your input. The following three simple tasks help demonstrate this:

  • Task 1: Distort a rectangle to a square.
  • Task 2: Distort a square to a rectangle.
  • Task 3: Distort a square to a rectangle -- with symmetry breaking.

We will see that we can quickly do Task 1, but not Task 2. Only by using symmetry breaking in Task 3 are we able to distort a square into a rectangle.
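Before building the network, it may help to see the constraint in miniature. Below is a minimal, hypothetical rotation-equivariant map (it just scales each point by its distance from the origin; nothing from e3nn): because it commutes with rotations, a 4-fold-symmetric input forces a 4-fold-symmetric output.

```python
import numpy as np

def f(points):
    # A toy equivariant map: scale each point by its norm, so f(R x) = R f(x).
    return points * np.linalg.norm(points, axis=-1, keepdims=True)

theta = np.pi / 2  # the square's 4-fold symmetry
R = np.array([[np.cos(theta), -np.sin(theta), 0.],
              [np.sin(theta),  np.cos(theta), 0.],
              [0., 0., 1.]])

square = np.array([[-.5, -.5, 0.], [.5, -.5, 0.], [.5, .5, 0.], [-.5, .5, 0.]])

# Equivariance: rotating the input rotates the output.
assert np.allclose(f(square @ R.T), f(square) @ R.T)

# The square maps to itself under R (up to point ordering), so the output
# inherits the same 4-fold symmetry: each rotated output point is an output point.
out = f(square)
rotated_out = out @ R.T
assert all(np.isclose(rotated_out[i], out).all(-1).any() for i in range(4))
```

The same argument, run in reverse, is exactly why Task 2 will fail: an equivariant map cannot produce an output less symmetric than its input.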

In [1]:
import torch
from functools import partial
import numpy as np

import e3nn
import e3nn.o3 as o3
from e3nn.point.operations import Convolution
from e3nn.non_linearities import GatedBlock
from e3nn.kernel import Kernel
from e3nn.radial import CosineBasisModel
from e3nn.non_linearities import rescaled_act

import matplotlib.pyplot as plt
%matplotlib inline

from spherical import SphericalTensor

torch.set_default_dtype(torch.float64)
In [2]:
# Define our geometry
square = torch.tensor(
    [[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]]
)
square -= square.mean(-2)
sx, sy = 0.5, 1.5
rectangle = square * torch.tensor([sx, sy, 0.])
rectangle -= rectangle.mean(-2)

N, _ = square.shape

markersize = 15

def plot_task(ax, start, finish, title, marker=None):
    # Append the first point again so the outline closes into a loop.
    ax.plot(torch.cat([start[:, 0], start[:1, 0]]), 
            torch.cat([start[:, 1], start[:1, 1]]), 'o-', 
            markersize=markersize + 5 if marker else markersize, 
            marker=marker if marker else 'o')
    ax.plot(torch.cat([finish[:, 0], finish[:1, 0]]), 
            torch.cat([finish[:, 1], finish[:1, 1]]), 'o-', markersize=markersize)
    for i in range(N):
        ax.arrow(start[i, 0], start[i, 1], 
                 finish[i, 0] - start[i, 0], 
                 finish[i, 1] - start[i, 1],
                 length_includes_head=True, head_width=0.05, facecolor="black", zorder=100)

    ax.set_title(title)
    ax.set_axis_off()

fig, axes = plt.subplots(1, 3, figsize=(15, 6))
plot_task(axes[0], rectangle, square, "Task 1: Rectangle to Square")
plot_task(axes[1], square, rectangle, "Task 2: Square to Rectangle")
plot_task(axes[2], square, rectangle, "Task 3: Square to Rectangle with Symmetry Breaking", "$\u2B2E$")

In these tasks, we want to move 4 points from one configuration to another. The input to the network will be the initial geometry and features on that geometry. The output will signify the "displacement" of each point to the new configuration. We can represent displacement in a couple of different ways. The simplest is an L=1 vector, Rs=[(1, 1)]. However, to better illustrate the symmetry properties of the network, we are instead going to use a spherical harmonic signal, or more specifically the peak of the spherical harmonic signal, to signify the displacement of each original point.
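As a warm-up for the "peak of the signal" idea, here is a small numpy sketch (independent of e3nn) showing that an L=1-only spherical harmonic signal, which is proportional to $f(\hat{n}) = \vec{v} \cdot \hat{n}$, peaks exactly along the encoded vector $\vec{v}$:

```python
import numpy as np

rng = np.random.default_rng(0)
v = np.array([0.3, -0.8, 0.5])              # displacement to encode

# Sample many unit directions; an L=1 signal f(n) = v . n peaks at n = v / |v|.
n = rng.normal(size=(100_000, 3))
n /= np.linalg.norm(n, axis=1, keepdims=True)
f = n @ v

peak = n[np.argmax(f)]                      # direction of the signal's maximum
assert np.allclose(peak, v / np.linalg.norm(v), atol=0.05)
```

The network's output signals contain higher L as well, but the principle is the same: read off the displacement as the direction where the signal is largest.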

First, we set up a very basic network that has the same representation list Rs = [(1, L) for L in range(5 + 1)] throughout the entire network. The input will be a spherical tensor with representation Rs and the output will also be a spherical tensor with representation Rs. We will interpret the output of the network as a spherical harmonic signal where the peak location will signify the desired displacement.

In [3]:
class Network(torch.nn.Module):
    def __init__(self, Rs, n_layers=3, max_radius=3.0, number_of_basis=3, radial_layers=3):
        super().__init__()
        self.Rs = Rs
        self.n_layers = n_layers
        self.L_max = max(L for m, L in Rs)

        sp = rescaled_act.Softplus(beta=5)

        # Radial basis network shared by all kernels
        RadialModel = partial(CosineBasisModel, max_radius=max_radius,
                              number_of_basis=number_of_basis, h=100,
                              L=radial_layers, act=sp)

        K = partial(Kernel, RadialModel=RadialModel)

        def make_layer(Rs_in, Rs_out):
            act = GatedBlock(Rs_out, sp, rescaled_act.sigmoid)
            conv = Convolution(K, Rs_in, act.Rs_in)
            return torch.nn.ModuleList([conv, act])

        # (n_layers - 1) gated convolution layers, same Rs throughout
        self.layers = torch.nn.ModuleList([
            make_layer(Rs, Rs)
            for i in range(n_layers - 1)
        ])

        # Final convolution without a nonlinearity
        self.lastlayer = torch.nn.ModuleList([
            Convolution(K, Rs, Rs)
        ])

    def forward(self, input, geometry):
        output = input
        batch, n_points, _ = geometry.shape
        for conv, act in self.layers:
            # Divide by sqrt(n_points) so the sum over points keeps unit scale
            output = conv(output.div(n_points ** 0.5), geometry)
            output = act(output)

        for layer in self.lastlayer:
            output = layer(output.div(n_points ** 0.5), geometry)
        return output
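The division by the square root of the point count inside `forward` is a normalization: a convolution sums contributions from all points, and summing N independent unit-variance terms gives variance N, so dividing by √N keeps the output scale roughly independent of the point count. A small numpy illustration of that scaling (independent of e3nn):

```python
import numpy as np

rng = np.random.default_rng(0)
n_points = 4

# Summing n_points independent unit-variance contributions gives variance
# n_points; dividing the inputs by sqrt(n_points) restores unit variance.
x = rng.normal(size=(100_000, n_points))
print(round(x.sum(axis=1).var()))                     # → 4
print(round((x / n_points**0.5).sum(axis=1).var()))   # → 1
```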

Task 1: Distort a rectangle to a square.

In this task, our input is four points in the shape of a rectangle with simple scalars (1.0) at each point. The task is to learn to displace the points to form a (more symmetric) square.

In [4]:
L_max = 5
Rs = [(1, L) for L in range(L_max + 1)]

model = Network(Rs)
print(model)

params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-3)
loss_fn = torch.nn.MSELoss()
Network(
  (layers): ModuleList(
    (0): ModuleList(
      (0): Convolution(
        (kernel): Kernel (0,1,2,3,4,5 -> 0,1,2,3,4,5,5x0)
      )
      (1): GatedBlock()
    )
    (1): ModuleList(
      (0): Convolution(
        (kernel): Kernel (0,1,2,3,4,5 -> 0,1,2,3,4,5,5x0)
      )
      (1): GatedBlock()
    )
  )
  (lastlayer): ModuleList(
    (0): Convolution(
      (kernel): Kernel (0,1,2,3,4,5 -> 0,1,2,3,4,5)
    )
  )
)
In [5]:
input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel

displacements = square - rectangle
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i], L_max).signal for i in range(N)])
In [6]:
iterations = 200
for i in range(iterations):
    output = model(input, rectangle.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if i % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor(0.0068, grad_fn=<MseLossBackward>)
tensor(0.0017, grad_fn=<MseLossBackward>)
tensor(0.0011, grad_fn=<MseLossBackward>)
tensor(0.0010, grad_fn=<MseLossBackward>)
tensor(0.0008, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0006, grad_fn=<MseLossBackward>)
tensor(0.0005, grad_fn=<MseLossBackward>)
tensor(0.0004, grad_fn=<MseLossBackward>)
tensor(0.0003, grad_fn=<MseLossBackward>)
tensor(0.0002, grad_fn=<MseLossBackward>)
tensor(0.0002, grad_fn=<MseLossBackward>)
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(6.9074e-05, grad_fn=<MseLossBackward>)
tensor(4.5870e-05, grad_fn=<MseLossBackward>)
tensor(2.9675e-05, grad_fn=<MseLossBackward>)
tensor(1.8917e-05, grad_fn=<MseLossBackward>)
tensor(1.1897e-05, grad_fn=<MseLossBackward>)
tensor(7.2867e-06, grad_fn=<MseLossBackward>)
In [7]:
# Plot spherical harmonic projections
import plotly
from plotly.subplots import make_subplots
import plotly.graph_objects as go
In [8]:
def plot_output(start, finish, output, start_label, finish_label):
    rows, cols = 1, 1
    specs = [[{'is_3d': True} for i in range(cols)]
             for j in range(rows)]
    fig = make_subplots(rows=rows, cols=cols, specs=specs)
    fig.add_trace(go.Scatter3d(x=start[:, 0], y=start[:, 1], z=start[:, 2], mode="markers", name=start_label))
    fig.add_trace(go.Scatter3d(x=finish[:, 0], y=finish[:, 1], z=finish[:, 2], mode="markers", name=finish_label))
    for i in range(N):
        trace = SphericalTensor(output[0][i].detach(), Rs).plot(center=start[i])
        trace.showscale = False
        fig.add_trace(trace, 1, 1)
    return fig
In [9]:
output = model(input, rectangle.unsqueeze(0))
fig = plot_output(rectangle, square, output, "Rectangle", "Square")
fig.update_layout(scene_aspectmode='data')
fig.show()

And let's check that the network is equivariant.

In [10]:
angles = torch.rand(3) * torch.tensor([np.pi, 2 * np.pi, np.pi])
rot = o3.rot(*angles)
rot_rectangle = torch.einsum('xy,jy->jx', (rot, rectangle))
rot_square = torch.einsum('xy,jy->jx', (rot, square))
output = model(input, rot_rectangle.unsqueeze(0))
fig = plot_output(rot_rectangle, rot_square, output, "Rectangle", "Square")
fig.update_layout(scene_aspectmode='data')
fig.show()

Task 2: Now the reverse! Distort a square to rectangle.

In this task, our input is four points in the shape of a square with simple scalars (1.0) at each point. The task is to learn to displace the points to form a (less symmetric) rectangle. Can the network learn this task?

In [11]:
model = Network(Rs)

params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-3)
loss_fn = torch.nn.MSELoss()
In [12]:
input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel

displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i], L_max).signal for i in range(N)])
In [13]:
iterations = 100
for i in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if i % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor(0.0138, grad_fn=<MseLossBackward>)
tensor(0.0017, grad_fn=<MseLossBackward>)
tensor(0.0012, grad_fn=<MseLossBackward>)
tensor(0.0010, grad_fn=<MseLossBackward>)
tensor(0.0009, grad_fn=<MseLossBackward>)
tensor(0.0008, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)

Hmm... seems to get stuck. Let's try more iterations.

In [14]:
iterations = 100
for i in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if i % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)

It's stuck. What's going on?

In [15]:
fig = plot_output(square, rectangle, output, "Square", "Rectangle")
fig.update_layout(scene_aspectmode='data')
fig.show()

The symmetry of the output must be higher than or equal to the symmetry of the input!

To be able to do this task, you need to give the network more information -- information that breaks the symmetry to that of the desired output.
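We can make the obstruction concrete with a small numpy check (not part of the original notebook). A 90° rotation R maps the square onto itself, permuting the corners by some σ, so any equivariant map must output displacements satisfying d[σ(i)] = R d[i]. The rectangle displacements violate this constraint, while Task 1 faces no such constraint because the rectangle is not 90°-symmetric:

```python
import numpy as np

square = np.array([[-.5, -.5, 0.], [.5, -.5, 0.], [.5, .5, 0.], [-.5, .5, 0.]])
rectangle = square * np.array([0.5, 1.5, 1.0])
d = rectangle - square                     # Task 2 target displacements

R = np.array([[0., -1., 0.],               # 90 degree rotation about z
              [1.,  0., 0.],
              [0.,  0., 1.]])

# The square maps to itself: R sends corner i to corner sigma[i].
sigma = [int(np.flatnonzero(np.isclose(square, p).all(1))[0]) for p in square @ R.T]
assert sorted(sigma) == [0, 1, 2, 3]

# Equivariance would force d[sigma[i]] == R d[i]; measure the violation.
violation = np.abs(d[sigma] - d @ R.T).max()
print(violation)   # → 0.5, nonzero: no equivariant map can do Task 2 as posed
```

This is exactly why the loss plateaus: the network converges to the most symmetric compromise it is allowed to output.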

Task 3: Fixing Task 2. Distort a square into a rectangle -- now, with symmetry breaking!

In this task, our input is four points in the shape of a square with simple scalars (1.0) AND a contribution to the $x^2 - y^2$ feature at each point. The task is to learn to displace the points to form a (less symmetric) rectangle. Can the network learn this task?
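Why index 8? With Rs = [(1, L) for L in range(L_max + 1)] the feature vector is a concatenation of irrep blocks of size 2L + 1. Assuming components within each block are ordered m = -L, ..., +L (the real-spherical-harmonic convention in which the (L=2, m=+2) component is proportional to $x^2 - y^2$, matching the index used in the next cell), the flat index works out as:

```python
L, m = 2, 2                                  # the x^2 - y^2 real spherical harmonic
offset = sum(2 * l + 1 for l in range(L))    # size of all lower-L blocks: 1 + 3 = 4
index = offset + (m + L)                     # m runs from -L to +L within a block
print(index)  # → 8
```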

In [16]:
model = Network(Rs)

params = model.parameters()
optimizer = torch.optim.Adam(params, 1e-3)
loss_fn = torch.nn.MSELoss()
In [17]:
input = torch.zeros(1, N, sum(2 * L + 1 for L in range(L_max + 1)))
input[:, :, 0] = 1.  # batch, point, channel
# Breaking x and y symmetry with x^2 - y^2 component
input[:, :, 8] = 0.1  # x^2 - y^2

displacements = rectangle - square
N, _ = displacements.shape
projections = torch.stack([SphericalTensor.from_geometry(displacements[i], L_max).signal for i in range(N)])
In [18]:
iterations = 200
for i in range(iterations):
    output = model(input, square.unsqueeze(0))
    loss = loss_fn(output, projections.unsqueeze(0))
    if i % 10 == 0:
        print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor(0.0108, grad_fn=<MseLossBackward>)
tensor(0.0021, grad_fn=<MseLossBackward>)
tensor(0.0013, grad_fn=<MseLossBackward>)
tensor(0.0011, grad_fn=<MseLossBackward>)
tensor(0.0010, grad_fn=<MseLossBackward>)
tensor(0.0008, grad_fn=<MseLossBackward>)
tensor(0.0007, grad_fn=<MseLossBackward>)
tensor(0.0006, grad_fn=<MseLossBackward>)
tensor(0.0005, grad_fn=<MseLossBackward>)
tensor(0.0004, grad_fn=<MseLossBackward>)
tensor(0.0003, grad_fn=<MseLossBackward>)
tensor(0.0002, grad_fn=<MseLossBackward>)
tensor(0.0002, grad_fn=<MseLossBackward>)
tensor(0.0002, grad_fn=<MseLossBackward>)
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(0.0001, grad_fn=<MseLossBackward>)
tensor(9.0100e-05, grad_fn=<MseLossBackward>)
tensor(7.2446e-05, grad_fn=<MseLossBackward>)
tensor(5.6702e-05, grad_fn=<MseLossBackward>)
tensor(4.2933e-05, grad_fn=<MseLossBackward>)
In [19]:
fig = plot_output(square, rectangle, output, "Square", "Rectangle")
fig.update_layout(scene_aspectmode='data')
fig.show()

What is the $x^2 - y^2$ term doing? It's breaking the symmetry along the $\hat{x}$ and $\hat{y}$ directions.

Notice how the shape below is an ellipsoid elongated in the $\hat{y}$ direction and squished in the $\hat{x}$ direction. This isn't the only perturbation we could have added, but it is the most symmetric.
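We can check that $x^2 - y^2$ has exactly the symmetry of the rectangle and no more: it is invariant under the reflections $x \to -x$ and $y \to -y$ (which the rectangle keeps), but changes sign under the 90° rotation $x \to -y,\ y \to x$ (the extra symmetry of the square that we need to break). A quick numpy sketch:

```python
import numpy as np

def f(p):
    # The symmetry-breaking feature evaluated at points p
    x, y = p[..., 0], p[..., 1]
    return x**2 - y**2

rng = np.random.default_rng(1)
p = rng.normal(size=(100, 3))

flip_x = p * np.array([-1., 1., 1.])   # reflection: a rectangle symmetry
flip_y = p * np.array([1., -1., 1.])   # reflection: a rectangle symmetry
rot90 = np.stack([-p[:, 1], p[:, 0], p[:, 2]], axis=1)  # square-only symmetry

assert np.allclose(f(flip_x), f(p))    # preserved: rectangle symmetries survive
assert np.allclose(f(flip_y), f(p))
assert np.allclose(f(rot90), -f(p))    # broken: the 4-fold symmetry flips its sign
```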

In [20]:
rows, cols = 1, 1
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]
fig = make_subplots(rows=rows, cols=cols, specs=specs)

L_max = 5
Rs = [(1, L) for L in range(L_max + 1)]
sum_Ls = sum(2 * L + 1 for mult, L in Rs) 

# Spherical tensor with a scalar and an x^2 - y^2 component
signal = torch.zeros(sum_Ls)
signal[0] = 1
# Breaking x and y symmetry with x^2 - y^2
signal[8] = -0.1

sphten = SphericalTensor(signal, Rs)

trace = sphten.plot(relu=False, n=60)
fig.add_trace(trace, row=1, col=1)
fig.show()